/*******************************************************************************
 * Package: com.song.sql.entity
 * Type:    CityDetailAgg
 * Date:    2024-12-25 16:55
 *
 * Copyright (c) 2024 LTD All Rights Reserved.
 *
 * You may not use this file except in compliance with the License.
 *******************************************************************************/
package com.song.sql.entity;

import com.song.sql.teacher.CityCount;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.Encoder;

import java.util.*;
import java.util.stream.Collectors;


/**
 * Spark SQL {@link Aggregator} that summarizes how the rows of a group are distributed
 * across cities.
 *
 * <p>Each input value is a city name. The buffer ({@link CityDetailBuff}) tracks the total
 * number of rows seen and a per-city row count. {@link #finish(CityDetailBuff)} renders the
 * most frequent cities as whole-number percentages of the total, followed by a catch-all
 * entry for the remaining cities.
 *
 * @author Songxianyang
 * @date 2024-12-25 16:55
 */
public class CityDetailAgg extends Aggregator<String, CityDetailBuff, String> {
    private static final long serialVersionUID = 5045933557715981550L;

    /** Number of most frequent cities that are listed individually in the result. */
    private static final int TOP_CITY_LIMIT = 2;

    /** Label of the catch-all entry that covers every city beyond the listed ones. */
    private static final String OTHER_LABEL = "其他";

    /**
     * Creates the initial (empty) buffer: zero rows and no per-city counts.
     *
     * @return a fresh buffer with a total of 0 and an empty, mutable map
     */
    @Override
    public CityDetailBuff zero() {
        return new CityDetailBuff(0L, new HashMap<>());
    }

    /**
     * Folds one input row into the buffer.
     *
     * @param b        the buffer to update; it is mutated and returned
     * @param cityName the city of the current row
     * @return the same buffer instance with the total and the city's count incremented
     */
    @Override
    public CityDetailBuff reduce(CityDetailBuff b, String cityName) {
        b.setTotalNumber(b.getTotalNumber() + 1);
        Map<String, Long> map = b.getMap();
        // The first occurrence of a city must count as 1 (not 0); otherwise the per-city
        // counts no longer add up to the total and every percentage is understated.
        map.merge(cityName, 1L, Long::sum);
        // Write the map back to the buffer.
        b.setMap(map);
        return b;
    }

    /**
     * Combines two partial buffers, e.g. the results of different partitions.
     *
     * @param b1 the buffer that receives the combined state; it is mutated and returned
     * @param b2 the buffer whose state is added to {@code b1}; it is not modified
     * @return {@code b1} holding the summed total and the per-city sums of both buffers
     */
    @Override
    public CityDetailBuff merge(CityDetailBuff b1, CityDetailBuff b2) {
        b1.setTotalNumber(b1.getTotalNumber() + b2.getTotalNumber());

        // Start from a copy of b1's counts and add b2's counts on top, summing shared keys.
        Map<String, Long> merged = new HashMap<>(b1.getMap());
        for (Map.Entry<String, Long> entry : b2.getMap().entrySet()) {
            Long count = entry.getValue();
            // Map.merge rejects null values; a missing count contributes nothing.
            if (count != null) {
                merged.merge(entry.getKey(), count, Long::sum);
            }
        }

        b1.setMap(merged);
        return b1;
    }

    /**
     * Renders the final buffer as a textual summary.
     *
     * <p>The cities are sorted by count in descending order. Up to {@value #TOP_CITY_LIMIT}
     * cities are written as {@code "<city> <percent>%"}, where the percentage is the city's
     * share of the total, truncated to a whole number. If more cities exist, an
     * {@code "其他 <percent>%"} entry with the remaining percentage is appended.
     *
     * @param reduction the fully merged buffer
     * @return the summary string, or an empty string when the buffer holds no rows
     */
    @Override
    public String finish(CityDetailBuff reduction) {
        Long totalNumber = reduction.getTotalNumber();
        Map<String, Long> map = reduction.getMap();
        // Nothing was aggregated: avoid dividing by zero and indexing into an empty list.
        if (totalNumber == null || totalNumber <= 0 || map == null || map.isEmpty()) {
            return "";
        }

        List<CityCount> cityCounts = new ArrayList<>(map.size());
        for (Map.Entry<String, Long> entry : map.entrySet()) {
            cityCounts.add(new CityCount(entry.getKey(), entry.getValue()));
        }

        // Descending by count.
        cityCounts.sort(Comparator.comparing(CityCount::getCount).reversed());

        // NOTE(review): entries are concatenated without a delimiter
        // (e.g. "A 50%B 30%其他 20%"); kept as-is to preserve the existing output format.
        StringBuilder detail = new StringBuilder();
        long listedPercent = 0L;
        // A group may contain fewer distinct cities than the limit.
        int listed = Math.min(TOP_CITY_LIMIT, cityCounts.size());
        for (int i = 0; i < listed; i++) {
            CityCount cityCount = cityCounts.get(i);
            long percent = cityCount.getCount() * 100 / totalNumber; // 10 * 100 / 20 => 50
            listedPercent += percent;
            detail.append(cityCount.getCityName()).append(' ').append(percent).append('%');
        }

        if (cityCounts.size() > TOP_CITY_LIMIT) {
            detail.append(OTHER_LABEL).append(' ').append(100 - listedPercent).append('%');
        }

        return detail.toString();
    }

    /**
     * Encoder for the intermediate buffer type.
     *
     * @return a bean encoder for {@link CityDetailBuff}
     */
    @Override
    public Encoder<CityDetailBuff> bufferEncoder() {
        return Encoders.bean(CityDetailBuff.class);
    }

    /**
     * Encoder for the final output type.
     *
     * @return the string encoder
     */
    @Override
    public Encoder<String> outputEncoder() {
        return Encoders.STRING();
    }
}
